import torch
from torch.utils.data import  DataLoader,Dataset

class MyDataset(Dataset):
    """Map-style dataset yielding aligned (input, tag_input, tag_output) triples.

    Each item is a tuple of three ``torch.long`` tensors built from the entries
    at the same index of ``input_data``, ``tag_input`` and ``tag_output``.

    Args:
        input_data: Indexable collection; its length at construction time is
            the dataset length. Entries are presumably sequences of integer
            ids — confirm against callers.
        tag_input: Indexable collection aligned index-by-index with
            ``input_data``.
        tag_output: Indexable collection aligned index-by-index with
            ``input_data``.
    """

    def __init__(self, input_data, tag_input, tag_output):
        super().__init__()
        # Length is fixed here from input_data only; the tag collections are
        # assumed to have at least as many entries (not validated).
        self.len = len(input_data)
        self.input_data = input_data
        self.tag_input = tag_input
        self.tag_output = tag_output

    def __getitem__(self, id):
        """Return ``(input, tag_input, tag_output)`` at position ``id`` as long tensors."""
        # `id` shadows the builtin; the name is kept for interface compatibility.
        # torch.tensor(..., dtype=torch.long) replaces the legacy
        # torch.LongTensor(...) constructor: given a bare int n, LongTensor
        # allocates an *uninitialized* tensor of length n instead of wrapping
        # the value. For list entries the result is unchanged.
        return (
            torch.tensor(self.input_data[id], dtype=torch.long),
            torch.tensor(self.tag_input[id], dtype=torch.long),
            torch.tensor(self.tag_output[id], dtype=torch.long),
        )

    def __len__(self):
        """Return the number of samples (length of ``input_data`` at construction)."""
        return self.len
